# =============================================================================
#  Analysis 3: This example script demonstrates Drosophila tracking analysis using supplementary movie 1 from Amat et al. (2014).
#
# [Amat et al., 2014]: Amat, F., Lemon, W., Mossing, D. P., McDole, K., Wan, Y., Branson, K., ... & Keller, P. J. (2014). Fast, accurate reconstruction of cell lineages from large-scale fluorescence microscopy data. Nature Methods, 11(9), 951.
#
# =============================================================================
"""
1. Load and read the video using moviepy
"""
def read_movie(moviefile, resize=1.):
    """
    Read every frame of a video file into a single uint8 numpy array.

    Frames are decoded with moviepy (https://zulko.github.io/moviepy/) and,
    when ``resize != 1``, spatially downsampled by a factor of ``resize``
    along rows and columns only; the colour-channel axis is never rescaled.

    Parameters
    ----------
    moviefile : str
        Path to a video file that moviepy's ``VideoFileClip`` can open.
    resize : float, optional
        Spatial downsampling factor, must be > 0. Each frame is resized to
        ``round(rows / resize) x round(cols / resize)``. Default 1 (no resizing).

    Returns
    -------
    numpy.ndarray
        uint8 array stacking all frames along a new first axis, i.e.
        (n_frames, rows, cols, channels) for colour video. An empty array is
        returned if the clip yields no frames.

    Raises
    ------
    ValueError
        If ``resize`` is not strictly positive.
    """
    from moviepy.editor import VideoFileClip
    from tqdm import tqdm
    # aliased because the ``resize`` parameter name would otherwise shadow it
    from skimage.transform import resize as sk_resize
    import numpy as np

    if resize <= 0:
        raise ValueError('resize must be > 0, got %r' % (resize,))
    scale = 1. / resize

    vidframes = []
    clip = VideoFileClip(moviefile)
    try:
        for frame in tqdm(clip.iter_frames()):
            if scale != 1.:
                # Give resize() only the spatial output shape so the channel
                # axis is preserved. A scalar rescale() would also shrink the
                # 3 colour channels to 1 on skimage versions where
                # multichannel/channel_axis is off by default.
                out_shape = (int(round(frame.shape[0] * scale)),
                             int(round(frame.shape[1] * scale)))
                frame = sk_resize(frame, out_shape, preserve_range=True)
            vidframes.append(np.uint8(frame))
    finally:
        # release the file reader held by the clip, even if decoding fails
        clip.close()

    return np.array(vidframes)
    
# Supplementary movie 1 of Amat et al. 2014, relative to the working directory.
moviefile = '../Data/Videos/nmeth.3036-sv1.avi'
# read_movie rescales each frame by 1/resize (here 1/4), presumably to keep
# the downstream tracking fast.
movie = read_movie(moviefile, resize=4.)

# Unpack the video dimensions (fails loudly if the array is not 4-D).
video_shape = movie.shape
n_frame, n_rows, n_cols, n_channels = video_shape
print('Size of video: (%d,%d,%d,%d)' % video_shape)

"""
2. Motion Extraction
"""
from MOSES.Optical_Flow_Tracking.superpixel_track import compute_grayscale_vid_superpixel_tracks_FB
# motion extraction parameters. The keys mirror the arguments of OpenCV's
# cv2.calcOpticalFlowFarneback (presumably forwarded to it by MOSES).
optical_flow_params = dict(pyr_scale=0.5, levels=5, winsize=21, iterations=5, poly_n=5, poly_sigma=1.2, flags=0)
# number of superpixels
n_spixels = 1000

# extract forward and backward tracks from channel index 1 of every frame
# (the green channel, assuming RGB frames)
optflow, meantracks_F, meantracks_B = compute_grayscale_vid_superpixel_tracks_FB(movie[:,:,:,1], optical_flow_params, n_spixels, dense=True)

# plot all the forward tracks over the first frame.
import pylab as plt
from MOSES.Visualisation_Tools.track_plotting import plot_tracks
fig, ax = plt.subplots()
ax.imshow(movie[0])
plot_tracks(meantracks_F, ax, color='r', lw=1.0, alpha=1)
plt.show()

# save the tracks in the current working directory as
# 'meantracks-<n_spixels>_<movie name without extension>.mat'
import os
import scipy.io as spio
fname = os.path.splitext(os.path.basename(moviefile))[0]
savetracksmat = 'meantracks-%d_%s.mat' % (n_spixels, fname)
# 'meantracks' holds the forward tracks; the backward tracks are stored
# alongside so the motion-source analysis can be reproduced from the file.
spio.savemat(savetracksmat, {'meantracks': meantracks_F, 'meantracks_B': meantracks_B})

"""
3. Track clustering and visualization using the tracks
"""
import numpy as np
# One feature vector per track: all row coordinates over time followed by all
# column coordinates (assumes meantracks_F is (n_tracks, n_frames, 2) with
# (row, col) in the last axis -- TODO confirm against MOSES).
X = np.hstack([meantracks_F[:,:,0], meantracks_F[:,:,1]])
# Uncomment to cluster on frame-to-frame velocity instead of position. Each
# coordinate is differenced separately so no spurious step is created at the
# row/column seam of the stacked matrix.
#X = np.hstack([np.diff(meantracks_F[:,:,0], axis=1), np.diff(meantracks_F[:,:,1], axis=1)]).astype(float)

from sklearn.decomposition import PCA

# reduce each track to 3 whitened principal components (fixed seed)
pca_model = PCA(n_components=3, whiten=True, random_state=0)
X_pca = pca_model.fit_transform(X)

from sklearn import mixture
n_clusters = 10
# full-covariance Gaussian mixture in PCA space; fixed seed for reproducibility
gmm = mixture.GaussianMixture(n_components=n_clusters, covariance_type='full', random_state=0)
gmm.fit(X_pca)

track_labels = gmm.predict(X_pca)

# plotting to visualise.
from MOSES.Visualisation_Tools.track_plotting import plot_tracks
import seaborn as sns
# One distinct colour per mixture component. Matplotlib's 'Set1' has only 9
# colours and seaborn cycles a palette when more are requested, which would give
# two clusters the same colour; use evenly spaced 'hls' hues beyond 9 clusters.
palette_name = 'Set1' if n_clusters <= 9 else 'hls'
cluster_colors = sns.color_palette(palette_name, n_clusters)

# overlay cluster tracks and clustered superpixels
fig, ax = plt.subplots(nrows=1,ncols=3, figsize=(15,15))
ax[0].imshow(movie[0]); ax[0].grid(False); ax[0].axis('off'); ax[0].set_title('Initial Points')
ax[1].imshow(movie[-1]); ax[1].grid(False); ax[1].axis('off'); ax[1].set_title('Final Points')
ax[2].imshow(movie[0]); ax[2].grid(False); ax[2].axis('off'); ax[2].set_title('Clustered Tracks')
for lab in np.unique(track_labels):
    # compute the cluster mask once; colour by component id so a cluster keeps
    # its colour even if some other component receives no tracks
    in_cluster = track_labels == lab
    color = cluster_colors[lab]
    cluster_tracks = meantracks_F[in_cluster]
    # plot coloured initial points (x = column, y = row)
    ax[0].plot(cluster_tracks[:,0,1], cluster_tracks[:,0,0], 'o', color=color, alpha=1)
    # plot coloured final points
    ax[1].plot(cluster_tracks[:,-1,1], cluster_tracks[:,-1,0], 'o', color=color, alpha=1)
    # plot coloured tracks
    plot_tracks(cluster_tracks, ax[2], color=color, lw=1.0, alpha=0.7)
#fig.savefig(os.path.join(saveanalysisfolder, 'motion-clustered-tracks_drosophila.svg'), bbox_inches='tight', dpi=300)
#fig.savefig('motion-clustered-tracks_drosophila_reanalysis.svg', bbox_inches='tight', dpi=300)
plt.show()

"""
4. Motion Source Analysis
"""
from MOSES.Motion_Analysis.mesh_statistics_tools import compute_motion_saliency_map
from skimage.exposure import equalize_hist
from skimage.filters import gaussian
# specify a large distance threshold to capture long-range interactions.
dist_thresh = 20
# Superpixel size estimated as (column - row) of track 1's initial position.
# NOTE(review): this equals the grid spacing only if the superpixel seeds lie on a
# regular square grid offset by half a spacing and track 1 is the second seed of
# the first row -- TODO confirm against MOSES' superpixel ordering.
spixel_size = meantracks_F[1,0,1]-meantracks_F[1,0,0]

# saliency from forward tracks (plotted as sinks) and backward tracks (sources)
motion_saliency_F, motion_saliency_spatial_time_F = compute_motion_saliency_map(meantracks_F, dist_thresh=dist_thresh,
                                                                  shape=movie.shape[1:-1], max_frame=None, filt=1, filt_size=spixel_size)
motion_saliency_B, motion_saliency_spatial_time_B = compute_motion_saliency_map(meantracks_B, dist_thresh=dist_thresh,
                                                                  shape=movie.shape[1:-1], max_frame=None, filt=1, filt_size=spixel_size)
# smooth the discrete looking motion saliency maps (sigma = half a superpixel).
motion_saliency_F_smooth = gaussian(motion_saliency_F, spixel_size/2.)
motion_saliency_B_smooth = gaussian(motion_saliency_B, spixel_size/2.)

# visualise the computed results. Note: for RGB frames matplotlib ignores
# cmap='gray'; it only applies if the frames are single-channel.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15,15))
ax[0].imshow(movie[0], cmap='gray'); ax[0].grid(False); ax[0].axis('off')
ax[1].imshow(movie[0], cmap='gray'); ax[1].grid(False); ax[1].axis('off')
ax[0].set_title('Motion Sinks')
ax[1].set_title('Motion Sources')
# histogram-equalise the maps into [0, 1] and overlay them semi-transparently
ax[0].imshow(equalize_hist(motion_saliency_F_smooth), cmap='coolwarm', alpha=0.5, vmin=0, vmax=1)
ax[1].imshow(equalize_hist(motion_saliency_B_smooth), cmap='coolwarm', alpha=0.5, vmin=0, vmax=1)
#fig.savefig(os.path.join(saveanalysisfolder, 'motion-source_drosophila.svg'), bbox_inches='tight', dpi=300)
plt.show()

